{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Preparation for Fine-tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we walk through the first step of fine-tuning: preparing the dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose we want to fine-tune our model for financial tasks. We found an open-source dataset that could be useful: [financial-qa-10k](https://huggingface.co/datasets/virattt/financial-qa-10K). Let's see how to properly prepare this dataset for fine-tuning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The raw dataset has the following structure:\n",
    "- 5 columns: 'question', 'answer', 'context', 'ticker', and 'filing'.\n",
    "- 7000 rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['question', 'answer', 'context', 'ticker', 'filing'],\n",
       "    num_rows: 7000\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"virattt/financial-qa-10K\", split=\"train\")\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data for Fine-tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Construct the dataset in the following format:\n",
    "\n",
    "``` python\n",
    "{\"query\": str, \"pos\": List[str], \"neg\": List[str], \"pos_scores\": List[int], \"neg_scores\": List[int], \"prompt\": str, \"type\": str}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- `query` is the query text.\n",
    "- `pos` is a list of positive texts.\n",
    "- `neg` is a list of negative texts. If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives.\n",
    "- `pos_scores` is a list of scores for `query` paired with each text in `pos`, and `neg_scores` is the same for `query` and `neg`. If you don't use knowledge distillation, both can be omitted.\n",
    "- `prompt` is the instruction used for the query; it overrides `query_instruction_for_retrieval`.\n",
    "- `type` is used for bge-en-icl and takes values such as `normal`, `symmetric_class`, `symmetric_clustering`, etc."
   ]
  },
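  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before writing records to disk, it can help to check them against this format. The sketch below is a minimal validator under stated assumptions: `check_record` is a hypothetical helper (not part of FlagEmbedding), only `query` and `pos` are treated as required, and score lists are only checked for type and for having one score per passage.\n",
    "\n",
    "```python\n",
    "from typing import Any, Dict\n",
    "\n",
    "# Hypothetical helper: field names and types follow the format above.\n",
    "REQUIRED = {\"query\": str, \"pos\": list}\n",
    "OPTIONAL = {\"neg\": list, \"pos_scores\": list, \"neg_scores\": list, \"prompt\": str, \"type\": str}\n",
    "\n",
    "\n",
    "def check_record(rec: Dict[str, Any]) -> None:\n",
    "    # Raise ValueError if `rec` does not match the training-data format.\n",
    "    for key, typ in REQUIRED.items():\n",
    "        if not isinstance(rec.get(key), typ):\n",
    "            raise ValueError(f\"'{key}' is required and must be a {typ.__name__}\")\n",
    "    for key, typ in OPTIONAL.items():\n",
    "        if key in rec and not isinstance(rec[key], typ):\n",
    "            raise ValueError(f\"'{key}' must be a {typ.__name__}\")\n",
    "    # every positive/negative passage must be a string\n",
    "    for key in (\"pos\", \"neg\"):\n",
    "        if not all(isinstance(t, str) for t in rec.get(key, [])):\n",
    "            raise ValueError(f\"all entries of '{key}' must be strings\")\n",
    "    # scores, if given, must line up one-to-one with their passages\n",
    "    for texts, scores in ((\"pos\", \"pos_scores\"), (\"neg\", \"neg_scores\")):\n",
    "        if scores in rec and len(rec[scores]) != len(rec.get(texts, [])):\n",
    "            raise ValueError(f\"'{scores}' needs one score per entry of '{texts}'\")\n",
    "\n",
    "\n",
    "check_record({\"query\": \"q\", \"pos\": [\"p\"], \"neg\": [\"n1\", \"n2\"], \"prompt\": \"Represent this sentence: \"})  # passes\n",
    "```\n",
    "\n",
    "A record that still stores `pos` as a plain string, as the dataset does right after the columns are renamed, fails the first check; this is why `pos` is wrapped into a list later in the tutorial."
   ]
  },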
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We select the columns 'question' and 'context' as our query and positive passage (`pos`) and rename them accordingly. Then we add an 'id' column for later use in evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n",
       " 'pos': 'Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.',\n",
       " 'id': '0'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds = ds.select_columns(column_names=[\"question\", \"context\"])\n",
    "ds = ds.rename_column(\"question\", \"query\")\n",
    "ds = ds.rename_column(\"context\", \"pos\")\n",
    "ds = ds.add_column(\"id\", [str(i) for i in range(len(ds))])\n",
    "ds[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Negative examples are important for training embedding models, but our dataset does not come with any. We therefore sample them at random from the whole corpus: 10 passages per query, excluding the query's own positive passage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 7000/7000 [00:00<00:00, 22336.83 examples/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(520)\n",
    "neg_num = 10  # number of negative texts sampled per query\n",
    "\n",
    "def str_to_lst(data):\n",
    "    # wrap the single positive passage in a list, as the training format expects\n",
    "    data[\"pos\"] = [data[\"pos\"]]\n",
    "    return data\n",
    "\n",
    "# sample negative texts: for each query, draw `neg_num` random row indices,\n",
    "# redrawing the whole batch if it contains the query's own row\n",
    "new_col = []\n",
    "for i in range(len(ds)):\n",
    "    ids = np.random.randint(0, len(ds), size=neg_num)\n",
    "    while i in ids:\n",
    "        ids = np.random.randint(0, len(ds), size=neg_num)\n",
    "    neg = [ds[j.item()][\"pos\"] for j in ids]\n",
    "    new_col.append(neg)\n",
    "ds = ds.add_column(\"neg\", new_col)\n",
    "\n",
    "# turn the value of 'pos' from a string into a list\n",
    "ds = ds.map(str_to_lst)"
   ]
  },
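  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above draws indices with replacement, so a query can receive the same negative more than once, and it only excludes the query's own row: if another row happens to contain the identical `context` text, that text can still be drawn as a negative. A minimal alternative sketch (not the method used in this tutorial) draws distinct rows with NumPy's `Generator.choice(..., replace=False)`, restricted to rows whose text differs from the positive. `sample_negatives` is a hypothetical name, and since it uses a different random generator it yields different negatives than the seeded loop above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_negatives(passages, neg_num, seed=520):\n",
    "    # For each row, pick `neg_num` distinct rows whose text differs from that row's text.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    texts = np.asarray(passages, dtype=object)\n",
    "    negatives = []\n",
    "    for pos_text in texts:\n",
    "        # candidate rows: every row whose passage is not identical to the positive\n",
    "        candidates = np.flatnonzero(texts != pos_text)\n",
    "        # replace=False means no row is picked twice (raises if there are too few candidates)\n",
    "        picked = rng.choice(candidates, size=neg_num, replace=False)\n",
    "        negatives.append([texts[j] for j in picked])\n",
    "    return negatives\n",
    "\n",
    "\n",
    "toy_passages = [\"a\", \"b\", \"c\", \"a\", \"d\"]\n",
    "print(sample_negatives(toy_passages, neg_num=2))\n",
    "```\n",
    "\n",
    "On this notebook's data it would be applied to the string-valued `pos` column, i.e. before the `str_to_lst` mapping, e.g. `sample_negatives(list(ds['pos']), neg_num)`. Each row is compared against every other row, so the cost grows quadratically with the number of rows, which is acceptable at this dataset's size."
   ]
  },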
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we add the prompt used for each query. During inference, the same string should be passed as `query_instruction_for_retrieval`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "instruction = \"Represent this sentence for searching relevant passages: \"\n",
    "ds = ds.add_column(\"prompt\", [instruction]*len(ds))"
   ]
  },
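  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of the prompt concrete, the sketch below assumes the instruction is simply prepended to the query text with no separator, which is why the instruction string ends with a space. This is an assumption about the formatting, not a description of FlagEmbedding internals; check the library's documentation for your version. `build_query_text` is a hypothetical helper, and as described above, only the query receives the instruction.\n",
    "\n",
    "```python\n",
    "def build_query_text(prompt: str, query: str) -> str:\n",
    "    # Assumption: the instruction is prepended verbatim, with no extra separator.\n",
    "    return f\"{prompt}{query}\"\n",
    "\n",
    "\n",
    "instruction = \"Represent this sentence for searching relevant passages: \"\n",
    "query = \"What area did NVIDIA initially focus on before expanding to other computationally intensive fields?\"\n",
    "print(build_query_text(instruction, query))\n",
    "# Represent this sentence for searching relevant passages: What area did NVIDIA initially focus on before expanding to other computationally intensive fields?\n",
    "```"
   ]
  },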
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A single row of the dataset now looks like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n",
       " 'pos': ['Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.'],\n",
       " 'id': '0',\n",
       " 'neg': ['Kroger expects that its value creation model will deliver total shareholder return within a target range of 8% to 11% over time.',\n",
       "  'CSB purchased First Mortgages of $2.9 billion during 2023.',\n",
       "  'See Note 13 to our Consolidated Financial Statements for information on certain legal proceedings for which there are contingencies.',\n",
       "  'Diluted earnings per share were $16.69 in fiscal 2022 compared to $15.53 in fiscal 2021.',\n",
       "  'In the year ended December 31, 2023, Total net sales and revenue increased primarily due to: (1) increased net wholesale volumes primarily due to increased sales of crossover vehicles and full-size pickup trucks, partially offset by decreased sales of mid-size pickup trucks; (2) favorable Price as a result of low dealer inventory levels and strong demand for our products; (3) favorable Mix associated with increased sales of full-size pickup trucks and full-size SUVs and decreased sales of vans, passenger cars and mid-size pickup trucks, partially offset by increased sales of crossover vehicles; and (4) favorable Other due to increased sales of parts and accessories.',\n",
       "  'As of December 31, 2023, we had 3,157 full-time employees.',\n",
       "  'Item 3. Legal Proceedings. The information contained in Note 18 ‘‘Commitments and Contingencies’’ included in Item 8 of this 10-K is incorporated herein by reference.',\n",
       "  'Under the amended 2019 Secured Facility, the maturity date is set to July 20, 2026.',\n",
       "  'Accounts receivable for Las Vegas Sands Corp. on December 31, 2023, totaled $685 million, with a provision for credit losses of $201 million, resulting in a net balance of $484 million.',\n",
       "  'Operating expenses as a percentage of segment net sales decreased 25 basis points for fiscal 2023 when compared to the previous fiscal year, primarily driven by strong sales growth and lower incremental COVID-19 related costs, partially offset by increased wage costs.'],\n",
       " 'prompt': 'Represent this sentence for searching relevant passages: '}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we split the dataset into a training set (90%) and a test set (10%)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)\n",
    "train = split[\"train\"]\n",
    "test = split[\"test\"]"
   ]
  },
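  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `test_size=0.1`, 10% of the 7000 rows (700) go to the test set and the remaining 6300 to the training set, with no row in both. The NumPy sketch below illustrates that arithmetic with a seeded shuffle-then-slice split. It is an illustration rather than the `datasets` implementation of `train_test_split`, so it selects different rows; `shuffled_split` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def shuffled_split(n_rows, test_size=0.1, seed=520):\n",
    "    # Shuffle the row indices with a fixed seed, then slice off the test share.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    perm = rng.permutation(n_rows)\n",
    "    n_test = math.ceil(test_size * n_rows)  # round the test share up to whole rows\n",
    "    return perm[n_test:], perm[:n_test]\n",
    "\n",
    "\n",
    "train_idx, test_idx = shuffled_split(7000)\n",
    "print(len(train_idx), len(test_idx))  # 6300 700\n",
    "assert set(train_idx.tolist()).isdisjoint(test_idx.tolist())\n",
    "```"
   ]
  },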
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to store the data for later fine-tuning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 39.73ba/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "16583481"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
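    "import os\n",
    "\n",
    "os.makedirs(\"ft_data\", exist_ok=True)  # make sure the output directory exists\n",
    "train.to_json(\"ft_data/training.json\")"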
   ]
  },
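  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Dataset.to_json` writes JSON Lines by default (one JSON object per line), so `training.json` holds one training record per line despite its `.json` extension. As a quick sanity check, a file in this layout can be read back with the standard library alone. The sketch below writes two toy records to a temporary file instead of touching `ft_data/`; `read_jsonl` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def read_jsonl(path):\n",
    "    # Parse one JSON object per non-empty line.\n",
    "    with open(path, encoding=\"utf-8\") as f:\n",
    "        return [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "\n",
    "records = [\n",
    "    {\"query\": \"q1\", \"pos\": [\"p1\"], \"id\": \"0\", \"neg\": [\"n1\", \"n2\"], \"prompt\": \"Represent this sentence: \"},\n",
    "    {\"query\": \"q2\", \"pos\": [\"p2\"], \"id\": \"1\", \"neg\": [\"n3\", \"n4\"], \"prompt\": \"Represent this sentence: \"},\n",
    "]\n",
    "path = os.path.join(tempfile.mkdtemp(), \"training.json\")\n",
    "with open(path, \"w\", encoding=\"utf-8\") as f:\n",
    "    for rec in records:\n",
    "        print(json.dumps(rec), file=f)  # print() appends the newline that separates records\n",
    "\n",
    "loaded = read_jsonl(path)\n",
    "print(len(loaded), loaded[0][\"pos\"])  # 2 ['p1']\n",
    "```"
   ]
  },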
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Test Data for Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last step is to construct the test data for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['query', 'pos', 'id', 'neg', 'prompt'],\n",
       "    num_rows: 700\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, select the columns for the queries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'id': '1289',\n",
       " 'text': 'How does Starbucks recognize the interest and penalties related to income tax matters on their financial statements?'}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "queries = test.select_columns(column_names=[\"id\", \"query\"])\n",
    "queries = queries.rename_column(\"query\", \"text\")\n",
    "queries[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, select the columns for the corpus. The corpus contains the passages of all 7000 rows, not only those of the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
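    "corpus = ds.select_columns(column_names=[\"id\", \"pos\"])\n",
    "# 'pos' was wrapped into a one-element list earlier, so unwrap it into a plain 'text' string\n",
    "corpus = corpus.map(lambda x: {\"text\": x[\"pos\"][0]}, remove_columns=[\"pos\"])"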
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, build the qrels, which record that each test query is relevant to its own passage in the corpus:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Flattening the indices: 100%|██████████| 700/700 [00:00<00:00, 180956.10 examples/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'qid': '1289', 'docid': '1289', 'relevance': 1}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qrels = test.select_columns([\"id\"])\n",
    "qrels = qrels.rename_column(\"id\", \"qid\")\n",
    "qrels = qrels.add_column(\"docid\", list(test[\"id\"]))\n",
    "qrels = qrels.add_column(\"relevance\", [1]*len(test))\n",
    "qrels[0]"
   ]
  },
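  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each qrels row marks one (query, passage) pair as relevant with `relevance` 1; here every query has exactly one relevant passage, whose id equals the query's id. As a sketch of how such qrels are typically consumed, the helpers below turn the rows into a `{qid: {docid: relevance}}` mapping and compute recall@k against a ranked retrieval result. `qrels_to_dict`, `recall_at_k`, and the toy ranking are hypothetical; the evaluation you run later may use other metrics or an evaluation library.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def qrels_to_dict(rows):\n",
    "    # rows with 'qid', 'docid', 'relevance' keys -> {qid: {docid: relevance}}\n",
    "    out = defaultdict(dict)\n",
    "    for r in rows:\n",
    "        out[r[\"qid\"]][r[\"docid\"]] = r[\"relevance\"]\n",
    "    return dict(out)\n",
    "\n",
    "\n",
    "def recall_at_k(qrels, run, k):\n",
    "    # Mean, over queries, of the fraction of relevant docids found in the top-k of `run`.\n",
    "    scores = []\n",
    "    for qid, rels in qrels.items():\n",
    "        relevant = {d for d, rel in rels.items() if rel > 0}\n",
    "        if not relevant:\n",
    "            continue\n",
    "        top_k = run.get(qid, [])[:k]\n",
    "        scores.append(len(relevant.intersection(top_k)) / len(relevant))\n",
    "    return sum(scores) / len(scores) if scores else 0.0\n",
    "\n",
    "\n",
    "toy_qrels = qrels_to_dict([\n",
    "    {\"qid\": \"1289\", \"docid\": \"1289\", \"relevance\": 1},\n",
    "    {\"qid\": \"7\", \"docid\": \"7\", \"relevance\": 1},\n",
    "])\n",
    "# hypothetical ranked docids (best first) returned by a retriever\n",
    "run = {\"1289\": [\"1289\", \"42\", \"3\"], \"7\": [\"42\", \"3\", \"7\"]}\n",
    "print(recall_at_k(toy_qrels, run, k=1))  # 0.5\n",
    "print(recall_at_k(toy_qrels, run, k=3))  # 1.0\n",
    "```"
   ]
  },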
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Store the test queries, corpus, and qrels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 210.42ba/s]\n",
      "Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 261.19ba/s]\n",
      "Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 591.08ba/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "30574"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "queries.to_json(\"ft_data/test_queries.jsonl\")\n",
    "corpus.to_json(\"ft_data/corpus.jsonl\")\n",
    "qrels.to_json(\"ft_data/test_qrels.jsonl\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
